import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax
sns.set_theme(style="white")
# Pretty printing.
console = genjax.pretty(width=70)
# Reproducibility.
key = jax.random.PRNGKey(314159)

The sequential Monte Carlo mini-language
Abstract
This notebook describes a mini-language for constructing sequential Monte Carlo (SMC) algorithms from composable pieces. The language is based on work by Alex Lew on an SMC algorithm language built on Prox, a language for compositional approximate densities. GenJAX also implements Prox densities and provides a modified version of the SMC mini-language, adapted to support efficient JAX compilation.
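To make "composable pieces" concrete, the following is a minimal sketch of a bootstrap-style SMC filter written in plain JAX. It does not use GenJAX's SMC mini-language. The function names (`init`, `extend`, `reweight`, `resample`, `smc_step`, `run_smc`) and the toy Gaussian random-walk model are hypothetical choices made for illustration. Each piece is a pure function, so the full sweep can be expressed with `jax.lax.scan` and compiled with `jax.jit`, which reflects the compilation concern mentioned above.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical SMC building blocks (NOT GenJAX's API) for a toy model:
#   x_0 ~ N(0, 1),   x_t = x_{t-1} + N(0, 1),   y_t = x_t + N(0, 1).


def init(key, n_particles):
    # Draw initial particles from the N(0, 1) prior; start with uniform log weights.
    particles = jax.random.normal(key, (n_particles,))
    return particles, jnp.zeros(n_particles)


def extend(key, particles, step_scale=1.0):
    # Propose each particle's next state with a Gaussian random-walk kernel.
    return particles + step_scale * jax.random.normal(key, particles.shape)


def reweight(particles, log_weights, observation, obs_scale=1.0):
    # Add the Gaussian log-likelihood of the current observation to each weight.
    return log_weights + norm.logpdf(observation, particles, obs_scale)


def resample(key, particles, log_weights):
    # Multinomial resampling: draw ancestor indices in proportion to the weights,
    # then reset the log weights to uniform.
    n = particles.shape[0]
    ancestors = jax.random.categorical(key, log_weights, shape=(n,))
    return particles[ancestors], jnp.zeros(n)


def smc_step(carry, inputs):
    # One SMC step composed from the pieces above: resample -> extend -> reweight.
    particles, log_weights, log_z = carry
    key, observation = inputs
    resample_key, extend_key = jax.random.split(key)
    particles, log_weights = resample(resample_key, particles, log_weights)
    particles = extend(extend_key, particles)
    log_weights = reweight(particles, log_weights, observation)
    # Weights were uniform after resampling, so the incremental marginal
    # likelihood estimate is the log of the mean of the new weights.
    n = particles.shape[0]
    log_z = log_z + logsumexp(log_weights) - jnp.log(n)
    return (particles, log_weights, log_z), None


@partial(jax.jit, static_argnums=2)
def run_smc(key, observations, n_particles):
    # The particle count fixes array shapes, so it must be static under jit.
    init_key, scan_key = jax.random.split(key)
    particles, log_weights = init(init_key, n_particles)
    step_keys = jax.random.split(scan_key, observations.shape[0])
    carry = (particles, log_weights, jnp.zeros(()))
    (particles, log_weights, log_z), _ = jax.lax.scan(
        smc_step, carry, (step_keys, observations)
    )
    return particles, log_weights, log_z


key = jax.random.PRNGKey(314159)
observations = jnp.array([0.3, 0.8, 1.1, 0.9, 1.5])
particles, log_weights, log_z = run_smc(key, observations, 1000)
```

In this sketch, swapping in a different proposal or resampling scheme only means replacing the corresponding function; `smc_step` and `run_smc` are unchanged.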